{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4475c302",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\n",
    "from timm.data import create_transform\n",
    "from timm.data.transforms import _pil_interp\n",
    "from classification.focalnet import FocalNet, build_transforms, build_transforms4display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4f871df8",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "build model\n",
    "'''\n",
    "img_size = 224\n",
    "\n",
    "# multi-scale FocalNets\n",
    "model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3]).cuda()\n",
    "\n",
    "# isotropic FocalNets\n",
    "# model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], use_layerscale=True, use_postln=True,).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "00969042",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "build data transform\n",
    "'''\n",
    "eval_transforms = build_transforms(img_size, center_crop=False)\n",
    "display_transforms = build_transforms4display(img_size, center_crop=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27994844",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "load checkpoint\n",
    "'''\n",
    "ckpt_path = \"focalnet_base_lrf.pth\"\n",
    "ckpt = torch.load(ckpt_path)\n",
    "model.load_state_dict(ckpt['model']).eva()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01dbc82e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f9953da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize modulator \n",
    "upsampler = nn.Upsample(scale_factor=4, mode='bilinear')\n",
    "\n",
    "img_folder = \"./figures\"\n",
    "img_paths = os.listdir(img_folder)\n",
    "for img_path in img_paths:\n",
    "    img = Image.open(img_folder + img_path)\n",
    "    img_t = eval_transforms(img) \n",
    "    img_d = display_transforms(img)\n",
    "    out = model(img_t.unsqueeze(0).cuda())    \n",
    "\n",
    "    fig=plt.figure(figsize=(16, 8))\n",
    "    \n",
    "    fig.add_subplot(1, 2, 1)       \n",
    "    img2d = img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy()\n",
    "    x = plt.imshow(img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy())     \n",
    "    plt.axis('off')\n",
    "    x.axes.get_xaxis().set_visible(False)\n",
    "    x.axes.get_yaxis().set_visible(False)    \n",
    "    plt.subplots_adjust(wspace=None, hspace=None)\n",
    "\n",
    "    fig.add_subplot(1, 2, 2)    \n",
    "    modulator = torch.abs((model.layers[layer].blocks[-1].modulation.modulator)).mean(1, keepdim=True)\n",
    "    modulator = upsampler(modulator)\n",
    "    x = plt.imshow((modulator.squeeze(1)).permute(1, 2, 0).cpu().detach().contiguous().numpy())    \n",
    "    plt.axis('off')\n",
    "    x.axes.get_xaxis().set_visible(False)\n",
    "    x.axes.get_yaxis().set_visible(False)    \n",
    "\n",
    "    plt.subplots_adjust(wspace=0, hspace=0)    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60988461",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize gating maps \n",
    "upsampler = nn.Upsample(scale_factor=4, mode='bilinear')\n",
    "\n",
    "img_folder = \"./figures\"\n",
    "img_paths = os.listdir(img_folder)\n",
    "for img_path in img_paths:\n",
    "    img = Image.open(img_folder + img_path)\n",
    "    img_t = eval_transforms(img) \n",
    "    img_d = display_transforms(img)\n",
    "    out = model(img_t.unsqueeze(0).cuda())    \n",
    "\n",
    "    fig=plt.figure(figsize=(16, 8))\n",
    "    \n",
    "    fig.add_subplot(1, 5, 1)       \n",
    "    img2d = img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy()\n",
    "    x = plt.imshow(img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy())     \n",
    "    plt.axis('off')\n",
    "    x.axes.get_xaxis().set_visible(False)\n",
    "    x.axes.get_yaxis().set_visible(False)    \n",
    "\n",
    "    gates = (model.layers[-1].blocks[-1].modulation.gates)\n",
    "    for i in range(4):\n",
    "        fig.add_subplot(1, 5, i+2)        \n",
    "        gates_i = (upsampler(gates[:, i:i+1])).cpu().detach()\n",
    "        plt.imshow(gates_i.permute(1,2,0).numpy())        \n",
    "        plt.axis('off')\n",
    "        x.axes.get_xaxis().set_visible(False)\n",
    "        x.axes.get_yaxis().set_visible(False)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79ad4638",
   "metadata": {},
   "outputs": [],
   "source": [
    "# display learned focal kernel weights\n",
    "\n",
    "fig=plt.figure(figsize=(8, 8))  \n",
    "for id_layer in range(4):\n",
    "    fig.add_subplot(1, 4, id_layer+1)\n",
    "    x = plt.imshow(model.layers[id_layer].blocks[-1].modulation.focal_layers[0][0].weight.data.mean(0).cpu().permute(1, 2, 0).contiguous().numpy())    \n",
    "    plt.axis('off')\n",
    "    x.axes.get_xaxis().set_visible(False)\n",
    "    x.axes.get_yaxis().set_visible(False)\n",
    "    \n",
    "fig=plt.figure(figsize=(8, 8))   \n",
    "for id_layer in range(4):\n",
    "    fig.add_subplot(1, 4, id_layer+1)\n",
    "    plt.imshow(model.layers[id_layer].blocks[-1].modulation.focal_layers[1][0].weight.data.mean(0).cpu().permute(1, 2, 0).contiguous().numpy())\n",
    "    plt.axis('off')\n",
    "    x.axes.get_xaxis().set_visible(False)\n",
    "    x.axes.get_yaxis().set_visible(False)\n",
    "    \n",
    "fig=plt.figure(figsize=(8, 8))   \n",
    "for id_layer in range(4):\n",
    "    fig.add_subplot(1, 4, id_layer+1)\n",
    "    plt.imshow(model.layers[id_layer].blocks[-1].modulation.focal_layers[2][0].weight.data.mean(0).cpu().permute(1, 2, 0).contiguous().numpy())    \n",
    "    plt.axis('off')\n",
    "    x.axes.get_xaxis().set_visible(False)\n",
    "    x.axes.get_yaxis().set_visible(False)    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
